{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Migrate LoggingTensorHook and StopAtStepHook to Keras callbacks\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/logging_stop_hook\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/logging_stop_hook.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "meUTrR4I6m1C"
      },
      "source": [
        "In TensorFlow 1, you use `tf.estimator.LoggingTensorHook` to monitor and log tensors, while `tf.estimator.StopAtStepHook` helps stop training at a specified step when training with `tf.estimator.Estimator`. This notebook demonstrates how to migrate from these APIs to their equivalents in TensorFlow 2 using custom Keras callbacks (`tf.keras.callbacks.Callback`) with `Model.fit`.\n",
        "\n",
        "Keras [callbacks](https://www.tensorflow.org/guide/keras/custom_callback) are objects that are called at different points during training/evaluation/prediction in the built-in Keras `Model.fit`/`Model.evaluate`/`Model.predict` APIs. You can learn more about callbacks in the `tf.keras.callbacks.Callback` API docs, as well as the [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback) and [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section) guides. For migrating from `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2, check out the [Migrate training with assisted logic](sessionrunhook_callback.ipynb) guide."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YdZSoIXEbhg-"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Start with imports and a simple dataset for demonstration purposes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iE0vSfMXumKI"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as tf1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m7rnGxsXtDkV"
      },
      "outputs": [],
      "source": [
        "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
        "labels = [[0.3], [0.5], [0.7]]\n",
        "\n",
        "# Define an input function.\n",
        "def _input_fn():\n",
        "  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4uXff1BEssdE"
      },
      "source": [
        "## TensorFlow 1: Log tensors and stop training with tf.estimator APIs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zW-X5cmzmkuw"
      },
      "source": [
        "In TensorFlow 1, you define various hooks to control the training behavior. Then, you pass these hooks to `tf.estimator.EstimatorSpec` through its `training_hooks` argument.\n",
        "\n",
        "In the example below:\n",
        "\n",
        "- To monitor/log tensors—for example, model weights or losses—you use `tf.estimator.LoggingTensorHook` (`tf.train.LoggingTensorHook` is its alias).\n",
        "- To stop training at a specific step, you use `tf.estimator.StopAtStepHook` (`tf.train.StopAtStepHook` is its alias)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lqe9obf7suIj"
      },
      "outputs": [],
      "source": [
        "def _model_fn(features, labels, mode):\n",
        "  dense = tf1.layers.Dense(1)\n",
        "  logits = dense(features)\n",
        "  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n",
        "  optimizer = tf1.train.AdagradOptimizer(0.05)\n",
        "  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n",
        "\n",
        "  # Define the stop hook.\n",
        "  stop_hook = tf1.train.StopAtStepHook(num_steps=2)\n",
        "\n",
        "  # Wrap the kernel and bias in identity ops to get the tensors to log.\n",
        "  kernel_name = tf.identity(dense.weights[0])\n",
        "  bias_name = tf.identity(dense.weights[1])\n",
        "  logging_weight_hook = tf1.train.LoggingTensorHook(\n",
        "      tensors=[kernel_name, bias_name],\n",
        "      every_n_iter=1)\n",
        "  # Log the training loss by the tensor object.\n",
        "  logging_loss_hook = tf1.train.LoggingTensorHook(\n",
        "      {'loss from LoggingTensorHook': loss},\n",
        "      every_n_secs=3)\n",
        "\n",
        "  # Pass all hooks to `EstimatorSpec`.\n",
        "  return tf1.estimator.EstimatorSpec(mode,\n",
        "                                     loss=loss,\n",
        "                                     train_op=train_op,\n",
        "                                     training_hooks=[stop_hook,\n",
        "                                                     logging_weight_hook,\n",
        "                                                     logging_loss_hook])\n",
        "\n",
        "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
        "\n",
        "# Begin training.\n",
        "# The training will stop after 2 steps, and the weights/loss will also be logged.\n",
        "estimator.train(_input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEmzBjfnsxwT"
      },
      "source": [
        "## TensorFlow 2: Log tensors and stop training with custom callbacks and Model.fit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "839R9i4xheI5"
      },
      "source": [
        "In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`) for training/evaluation, you can configure tensor monitoring and training stopping by defining custom Keras callbacks that subclass `tf.keras.callbacks.Callback`. Then, you pass them to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`). (Learn more in the [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback) guide.)\n",
        "\n",
        "In the example below:\n",
        "\n",
        "- To recreate the functionality of `StopAtStepHook`, define a custom callback (named `StopAtStepCallback` below) where you override the `on_batch_end` method to stop training after a certain number of steps.\n",
        "- To recreate the `LoggingTensorHook` behavior, define a custom callback (`LoggingTensorCallback`) where you record and output the logged tensors manually, since accessing tensors by name is not supported. You can also implement the logging frequency inside the custom callback. The example below prints the weights and the loss every two steps. Other strategies, such as logging every N seconds, are also possible."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "atVciNgPs0fw"
      },
      "outputs": [],
      "source": [
        "class StopAtStepCallback(tf.keras.callbacks.Callback):\n",
        "  def __init__(self, stop_step=None):\n",
        "    super().__init__()\n",
        "    self._stop_step = stop_step\n",
        "\n",
        "  def on_batch_end(self, batch, logs=None):\n",
        "    if self.model.optimizer.iterations >= self._stop_step:\n",
        "      self.model.stop_training = True\n",
        "      print('\\nstop training now')\n",
        "\n",
        "class LoggingTensorCallback(tf.keras.callbacks.Callback):\n",
        "  def __init__(self, every_n_iter):\n",
        "    super().__init__()\n",
        "    self._every_n_iter = every_n_iter\n",
        "    # Number of batches remaining until the next log.\n",
        "    self._log_count = every_n_iter\n",
        "\n",
        "  def on_batch_end(self, batch, logs=None):\n",
        "    self._log_count -= 1\n",
        "    if self._log_count > 0:\n",
        "      return\n",
        "    # Reset the countdown, then log the weights and the loss.\n",
        "    self._log_count = self._every_n_iter\n",
        "    print(\"Logging Tensor Callback: dense/kernel:\",\n",
        "          self.model.layers[0].weights[0])\n",
        "    print(\"Logging Tensor Callback: dense/bias:\",\n",
        "          self.model.layers[0].weights[1])\n",
        "    print(\"Logging Tensor Callback loss:\", logs[\"loss\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "30a8b71263e0"
      },
      "source": [
        "When finished, pass the new callbacks—`StopAtStepCallback` and `LoggingTensorCallback`—to the `callbacks` parameter of `Model.fit`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kip65sYBlKiu"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n",
        "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n",
        "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n",
        "model.compile(optimizer, \"mse\")\n",
        "\n",
        "# Begin training.\n",
        "# The training will stop after 2 steps, and the weights/loss will also be logged.\n",
        "model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),\n",
        "                              LoggingTensorCallback(every_n_iter=2)])"
      ]
    },
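    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As noted earlier, time-based logging is also possible. The cell below is a minimal sketch of that idea, not part of the original migration example: `LoggingEveryNSecsCallback` is a hypothetical name, and the callback assumes that wall-clock time measured with `time.monotonic` is an acceptable stand-in for the `every_n_secs` argument of `LoggingTensorHook`. To try it, pass an instance such as `LoggingEveryNSecsCallback(every_n_secs=3)` to the `callbacks` parameter of `Model.fit`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical callback (an illustration, not an official API): log the loss\n",
        "# at most once every `every_n_secs` seconds of wall-clock time.\n",
        "class LoggingEveryNSecsCallback(tf.keras.callbacks.Callback):\n",
        "  def __init__(self, every_n_secs):\n",
        "    super().__init__()\n",
        "    self._every_n_secs = every_n_secs\n",
        "    self._last_log_time = None  # Nothing has been logged yet.\n",
        "\n",
        "  def on_train_batch_end(self, batch, logs=None):\n",
        "    now = time.monotonic()\n",
        "    # Log on the first batch, then whenever enough time has elapsed.\n",
        "    if (self._last_log_time is None or\n",
        "        now - self._last_log_time >= self._every_n_secs):\n",
        "      self._last_log_time = now\n",
        "      print(\"Logging every N secs: loss:\", (logs or {}).get(\"loss\"))"
      ]
    },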
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "19508f4720f5"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Learn more about callbacks in:\n",
        "\n",
        "- API docs: `tf.keras.callbacks.Callback`\n",
        "- Guide: [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback)\n",
        "- Guide: [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section)\n",
        "\n",
        "You may also find the following migration-related resources useful:\n",
        "\n",
        "- The [Early stopping migration guide](early_stopping.ipynb): `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback\n",
        "- The [TensorBoard migration guide](tensorboard.ipynb): TensorBoard enables tracking and displaying metrics\n",
        "- The [Training with assisted logic migration guide](sessionrunhook_callback.ipynb): From `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "logging_stop_hook.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
